Our modeling goal is to predict the speaker of each line of dialogue.
https://juliasilge.com/blog/last-airbender/
library(tidyverse)
avatar_raw <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-08-11/avatar.csv")
── Column specification ──────────────────────────────────────
cols(
id = col_double(),
book = col_character(),
book_num = col_double(),
chapter = col_character(),
chapter_num = col_double(),
character = col_character(),
full_text = col_character(),
character_words = col_character(),
writer = col_character(),
director = col_character(),
imdb_rating = col_double()
)
avatar_raw %>%
count(character, sort = TRUE)
Rows with Scene Description are not dialogue; the main character Aang speaks the most lines overall. How does this change through the three “books” of the show?
library(tidytext)
avatar_raw %>%
filter(!is.na(character_words)) %>%
mutate(
book = fct_inorder(book),
character = fct_lump_n(character, 10)
) %>%
count(book, character) %>%
mutate(character = reorder_within(character, n, book)) %>%
ggplot(aes(n, character, fill = book)) +
geom_col(show.legend = FALSE) +
facet_wrap(~book, scales = "free") +
scale_y_reordered() +
labs(y = NULL)
Let’s create a dataset for our modeling question, and look at a few example lines.
avatar <- avatar_raw %>%
filter(!is.na(character_words)) %>%
mutate(aang = if_else(character == "Aang", "Aang", "Other")) %>%
select(aang, book, text = character_words)
avatar %>%
filter(aang == "Aang") %>%
sample_n(10) %>%
pull(text)
[1] "Bumi!"
[2] "I guess."
[3] "Jet, it's me, Aang! You don't have to do this!"
[4] "I never saw Gyatso again. Next thing I knew, I was waking up in your arms after you found me in the iceberg."
[5] "Ran and Shaw? There are two of them?"
[6] "Hey! Over here!"
[7] "Excuse me, I don't mean to bother you, but my friend's sick and we're on kind of a tight schedule. Wait! But I'm a great bridge between your world and mine! I know Hei Bai! We're close personal friends! Heeey! My name's Aang. I'm the Avatar."
[8] "I gotta say, Sokka, you continue to impress me with your ideas."
[9] "Whaaaahh!"
[10] "Hello, Spirit? Can you hear me? This is the Avatar speaking. I'm here to try to help stuff."
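What are the highest log odds words from Aang and other speakers? We can tokenize the lines into words and compute weighted log odds with the tidylo package.
library(tidylo)
avatar_lo <- avatar %>%
unnest_tokens(word, text) %>%
count(aang, word) %>%
bind_log_odds(aang, word, n) %>%
arrange(-log_odds_weighted)
avatar_lo %>%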
group_by(aang) %>%
top_n(15) %>%
ungroup() %>%
mutate(word = reorder(word, log_odds_weighted)) %>%
ggplot(aes(log_odds_weighted, word, fill = aang)) +
geom_col(alpha = 0.8, show.legend = FALSE) +
facet_wrap(~aang, scales = "free") +
labs(y = NULL)
Selecting by log_odds_weighted
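For intuition about what log_odds_weighted measures, here is a minimal base R sketch of weighted log odds for two groups. It follows the informative Dirichlet prior approach of Monroe, Colaresi, and Quinn, which, as we understand it, is what tidylo builds on. The sketch assumes that the prior count for each word is simply that word's total count across both groups. The helper weighted_log_odds() and the toy counts are hypothetical, and tidylo's actual implementation may differ in detail.

```r
# Weighted log odds for two groups (hypothetical helper, a sketch only).
# Assumption: the informative prior for each word is its total count across groups.
weighted_log_odds <- function(counts) {
  stopifnot(is.matrix(counts), ncol(counts) == 2)
  y_w    <- rowSums(counts)  # prior count for each word
  alpha0 <- sum(y_w)         # total prior mass
  n      <- colSums(counts)  # total tokens in each group
  # odds of each word in group k, smoothed by the prior
  omega <- function(k) (counts[, k] + y_w) / (n[k] + alpha0 - counts[, k] - y_w)
  delta  <- log(omega(1)) - log(omega(2))                      # log odds ratio, group 1 vs group 2
  sigma2 <- 1 / (counts[, 1] + y_w) + 1 / (counts[, 2] + y_w)  # approximate variance of delta
  z <- delta / sqrt(sigma2)                                    # weight by dividing by the standard error
  out <- cbind(z, -z)  # with two groups, group 2's score is the mirror image
  colnames(out) <- colnames(counts)
  out
}

counts <- matrix(
  c(30,  5,   # "avatar": mostly Aang
    10, 40,   # "sokka": mostly other speakers
    20, 22),  # "the": about the same in both
  ncol = 2, byrow = TRUE,
  dimnames = list(c("avatar", "sokka", "the"), c("Aang", "Other"))
)
lo <- weighted_log_odds(counts)
round(lo, 2)
```

Dividing by the standard error shrinks the scores of rarely used words. A word therefore needs to be both distinctive and reasonably frequent to rank highly.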
The highest log odds words make sense, but their counts are probably too low to build a good model with. Instead, let’s try using text features like the number of punctuation characters, the number of pronouns, and so forth.
library(textfeatures)
If the textfeatures package isn’t installed, this call fails with “there is no package called ‘textfeatures’”, so install it before continuing.
See the textfeatures documentation for more about these counting functions: https://textfeatures.mikewk.com/reference/count_functions.html
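To get a feel for what features like these are, here is a tiny base R sketch that counts characters, words, punctuation marks, exclamation points, and pronouns for each line. The helper count_text_features() and its short pronoun list are hypothetical. Both are much simpler than what textfeatures computes, so treat this as an illustration rather than a replacement.

```r
# Hypothetical helper: a few simple per-line text features (not the textfeatures implementation)
count_text_features <- function(text) {
  pronouns <- c("i", "me", "my", "mine", "myself", "we", "us", "our", "ours",
                "you", "your", "yours", "he", "him", "his", "she", "her", "hers",
                "it", "its", "they", "them", "their", "theirs")
  # lowercase, split on anything that isn't a letter or apostrophe, drop empty strings
  words <- lapply(strsplit(tolower(text), "[^a-z']+"), function(w) w[nzchar(w)])
  data.frame(
    n_chars    = nchar(text),
    n_words    = lengths(words),
    n_puncts   = lengths(regmatches(text, gregexpr("[[:punct:]]", text))),
    n_exclaims = lengths(regmatches(text, gregexpr("!", text, fixed = TRUE))),
    n_pronouns = vapply(words, function(w) sum(w %in% pronouns), integer(1))
  )
}

lines <- c("Jet, it's me, Aang! You don't have to do this!", "Hey! Over here!", "I guess.")
feats <- count_text_features(lines)
feats
```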
We can start by loading the tidymodels metapackage and splitting our data into training and testing sets.
library(tidymodels)
set.seed(123)
avatar_split <- initial_split(avatar, strata = aang)
avatar_train <- training(avatar_split)
avatar_test <- testing(avatar_split)
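Stratifying on aang means the split is done separately within each class. As a result, the training and testing sets keep roughly the same proportion of Aang’s lines. Here is a minimal base R sketch of that idea, assuming a 3/4 training proportion (the initial_split() default). The stratified_split() helper is hypothetical, and rsample handles details that this sketch ignores, such as binning numeric strata.

```r
# Hypothetical helper: stratified train/test split of row indices
stratified_split <- function(y, prop = 3/4) {
  idx_by_class <- split(seq_along(y), y)  # row indices for each class
  train_idx <- unlist(lapply(idx_by_class, function(idx) {
    # sample the same proportion of rows from every class
    idx[sample.int(length(idx), size = floor(prop * length(idx)))]
  }), use.names = FALSE)
  list(train = sort(train_idx), test = setdiff(seq_along(y), train_idx))
}

set.seed(123)
y <- rep(c("Aang", "Other"), times = c(200, 800))  # toy outcome, 20% Aang
s <- stratified_split(y)
table(y[s$train])  # Aang 150, Other 600
table(y[s$test])   # Aang 50, Other 200
```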
Next, let’s create cross-validation resamples of the training data to evaluate our models.
set.seed(234)
avatar_folds <- vfold_cv(avatar_train, strata = aang)
avatar_folds
# 10-fold cross-validation using stratification
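Stratified V-fold cross-validation works the same way. Within each class, rows are shuffled and dealt out to the folds in turn. Each fold then serves once as the assessment set while the remaining folds are used for fitting. Here is a minimal base R sketch, with a hypothetical stratified_folds() helper and toy data:

```r
# Hypothetical helper: assign each row to one of v folds, stratified by class
stratified_folds <- function(y, v = 10) {
  fold <- integer(length(y))
  for (idx in split(seq_along(y), y)) {
    # shuffle this class's rows, then deal them out to folds 1, 2, ..., v, 1, 2, ...
    fold[idx[sample.int(length(idx))]] <- rep_len(seq_len(v), length(idx))
  }
  fold
}

set.seed(234)
y <- rep(c("Aang", "Other"), times = c(200, 800))
folds <- stratified_folds(y)
table(folds, y)  # every fold gets 20 Aang rows and 80 Other rows

# for fold k: fit on rows where folds != k, evaluate on rows where folds == k
analysis_rows   <- which(folds != 1)
assessment_rows <- which(folds == 1)
```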
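Next, let’s preprocess our data to get it ready for modeling.
library(textrecipes)
library(themis)
avatar_rec <- recipe(aang ~ text, data = avatar_train) %>%
step_downsample(aang) %>%
step_textfeature(text) %>%
step_zv(all_predictors()) %>%
step_normalize(all_predictors())
avatar_prep <- prep(avatar_rec)
avatar_prep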
Data Recipe
Inputs:
Training data contained 7494 data points and no missing data.
Operations:
Down-sampling based on aang [trained]
Text feature extraction for text [trained]
Zero variance filter removed 15 items [trained]
Centering and scaling for 12 items [trained]
no non-missing arguments to max; returning -Inf
juice(avatar_prep)
Let’s walk through the steps in this recipe.
First, we must tell the recipe() what our model is going to be (using a formula here) and what data we are using. Next, we downsample based on our outcome, since there are many more lines spoken by characters other than Aang than by Aang. We create the text features using a step from the textrecipes package. Then we remove zero-variance variables. In this case, these include text features such as the counts of URLs and hashtags, which never vary across lines of dialogue. Finally, we center and scale the predictors because the kind of model we want to try out, an SVM, is sensitive to the scale of its predictors.
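To make the down-sampling, zero-variance, and normalization steps concrete, here is a base R sketch of what they do to a small made-up data frame of text features. The text feature extraction step is skipped, since the toy columns are already features. All helper names and data are hypothetical. In a real pipeline, these operations are estimated on the training data, and down-sampling is meant for the training data only.

```r
set.seed(42)
# Toy training data: 20 lines by Aang, 80 by others, plus three count features
toy <- data.frame(
  aang       = factor(rep(c("Aang", "Other"), times = c(20, 80))),
  n_words    = rpois(100, 8),
  n_exclaims = rpois(100, 1),
  n_urls     = 0  # constant, like URL counts in dialogue
)

# Down-sample: randomly keep as many rows of each class as the smallest class has
downsample <- function(df, outcome) {
  idx_by_class <- split(seq_len(nrow(df)), df[[outcome]])
  n_min <- min(lengths(idx_by_class))
  keep <- unlist(lapply(idx_by_class, function(idx) idx[sample.int(length(idx), n_min)]),
                 use.names = FALSE)
  df[sort(keep), , drop = FALSE]
}

# Zero-variance filter: drop predictors that take only a single value
remove_zv <- function(df, outcome) {
  preds <- setdiff(names(df), outcome)
  constant <- preds[vapply(df[preds], function(x) length(unique(x)) < 2, logical(1))]
  df[setdiff(names(df), constant)]
}

# Normalize: center each predictor to mean 0 and scale it to standard deviation 1
normalize <- function(df, outcome) {
  preds <- setdiff(names(df), outcome)
  df[preds] <- lapply(df[preds], function(x) (x - mean(x)) / sd(x))
  df
}

processed <- normalize(remove_zv(downsample(toy, "aang"), "aang"), "aang")
table(processed$aang)  # Aang 20, Other 20
names(processed)       # n_urls is gone
```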
We’re mostly going to use this recipe in a workflow(), so we don’t need to stress too much about whether to prep() it or not. Since we are going to compute variable importance, we will need to come back to juice(avatar_prep).
Let’s compare two different models, a random forest model and a support vector machine model. We start by creating the model specifications.
rf_spec <- rand_forest(trees = 1000) %>%
set_engine("ranger") %>%
set_mode("classification")
rf_spec
Random Forest Model Specification (classification)
Main Arguments:
trees = 1000
Computational engine: ranger
svm_spec <- svm_rbf(cost = 0.5) %>%
set_engine("kernlab") %>%
set_mode("classification")
svm_spec
Radial Basis Function Support Vector Machine Specification (classification)
Main Arguments:
cost = 0.5
Computational engine: kernlab
Next, let’s start putting together a tidymodels workflow(), a helper object for managing modeling pipelines with pieces that fit together like Lego blocks. Notice that there is no model yet: Model: None.
avatar_wf <- workflow() %>%
add_recipe(avatar_rec)
avatar_wf
══ Workflow ═════════════════════════════════════════════════
Preprocessor: Recipe
Model: None
── Preprocessor ─────────────────────────────────────────────
4 Recipe Steps
● step_downsample()
● step_textfeature()
● step_zv()
● step_normalize()
Now we can add a model and then fit it to each of the resamples. First, we fit the random forest model.
doParallel::registerDoParallel()
set.seed(1234)
rf_rs <- avatar_wf %>%
add_model(rf_spec) %>%
fit_resamples(
resamples = avatar_folds,
metrics = metric_set(roc_auc, accuracy, sens, spec),
control = control_resamples(save_pred = TRUE)
)
Second, we can fit the support vector machine model.
set.seed(2345)
svm_rs <- avatar_wf %>%
add_model(svm_spec) %>%
fit_resamples(
resamples = avatar_folds,
metrics = metric_set(roc_auc, accuracy, sens, spec),
control = control_resamples(save_pred = TRUE)
)
Attaching package: ‘kernlab’
The following object is masked from ‘package:scales’:
alpha
The following object is masked from ‘package:purrr’:
cross
The following object is masked from ‘package:ggplot2’:
alpha
We have fit each of our candidate models to our resampled training set.
collect_metrics(rf_rs)
conf_mat_resampled(rf_rs)
collect_metrics(svm_rs)
conf_mat_resampled(svm_rs)
Different, but not really better! The SVM model is better able to identify the positive cases (Aang’s lines), but at the expense of the negative cases (everyone else’s lines). Overall, this is clearly a hard problem, and we have barely any predictive ability for it.
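Let’s say we are more interested in detecting Aang’s lines, even at the expense of more false positives. That makes the SVM the more appealing of the two models, so let’s look at its ROC curve for each resample.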
svm_rs %>%
collect_predictions() %>%
group_by(id) %>%
roc_curve(aang, .pred_Aang) %>%
ggplot(aes(1 - specificity, sensitivity, color = id)) +
geom_abline(lty = 2, color = "black", size = 1) +
geom_path(show.legend = FALSE, alpha = 0.6, size = 0.5) +
coord_equal()
This plot highlights how this model is barely doing better than guessing.
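One way to read “barely better than guessing” is through the area under the ROC curve. This area equals the probability that a randomly chosen Aang line gets a higher .pred_Aang than a randomly chosen line from someone else. The dashed diagonal therefore corresponds to an area of 0.5. Here is a minimal base R sketch of that rank-based calculation. The roc_auc_rank() function is a hypothetical helper, not yardstick’s code.

```r
# ROC AUC via ranks: the chance that a random positive case outranks a random
# negative case, with ties counted as one half (the Mann-Whitney statistic)
roc_auc_rank <- function(truth, score, positive) {
  pos   <- truth == positive
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  r <- rank(score)  # tied scores get their average rank
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

truth <- c("Aang", "Aang", "Other", "Other")
roc_auc_rank(truth, c(0.9, 0.8, 0.3, 0.1), "Aang")  # 1: perfect ranking
roc_auc_rank(truth, c(0.5, 0.5, 0.5, 0.5), "Aang")  # 0.5: no information
```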
Keeping in mind the realities of our model performance, let’s talk about how to compute variable importance for a model like an SVM. Unlike a linear model or a tree-based model, an SVM does not carry information about variable importance within it. In this case, we can use a model-agnostic method such as permutation: shuffle one variable at a time and measure how much the model’s performance drops.
These are the text features that are most important globally for whether a line was spoken by Aang or not.
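The permutation idea is model-agnostic and short to write down. Here is a base R sketch that uses a logistic regression as a stand-in for the SVM. It runs on simulated data in which only one of two predictors carries signal. The helper names are hypothetical. A real analysis would use a dedicated package (vip supports permutation importance, for example) and would ideally evaluate on held-out data rather than the data used for fitting.

```r
set.seed(345)
# Simulated data: y depends on `signal` but not on `noise`
n <- 500
dat <- data.frame(signal = rnorm(n), noise = rnorm(n))
dat$y <- rbinom(n, 1, plogis(2 * dat$signal))
fit <- glm(y ~ signal + noise, data = dat, family = binomial)

# Performance metric: classification accuracy at a 0.5 probability threshold
accuracy <- function(model, data) {
  mean((predict(model, newdata = data, type = "response") > 0.5) == data$y)
}

# Permutation importance: average drop in the metric after shuffling one variable
permutation_importance <- function(model, data, vars, metric, nsim = 10) {
  baseline <- metric(model, data)
  vapply(vars, function(v) {
    drops <- replicate(nsim, {
      shuffled <- data
      shuffled[[v]] <- sample(shuffled[[v]])  # break the link between v and y
      baseline - metric(model, shuffled)
    })
    mean(drops)
  }, numeric(1))
}

imp <- permutation_importance(fit, dat, c("signal", "noise"), accuracy)
imp  # shuffling `signal` hurts accuracy far more than shuffling `noise`
```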
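Finally, we can return to the testing data to confirm that our performance is about the same. The last_fit() function fits the SVM workflow one final time to the whole training set and evaluates it on the testing set.
avatar_final <- avatar_wf %>%
add_model(svm_spec) %>%
last_fit(avatar_split)
avatar_final %>%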
collect_predictions() %>%
conf_mat(aang, .pred_class)
          Truth
Prediction Aang Other
     Aang   261  1004
     Other  188  1045
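To express the test set confusion matrix in the same terms as the metrics used during resampling, we can compute sensitivity, specificity, and accuracy by hand, treating Aang as the positive class. Here is a base R sketch that uses the counts above:

```r
# Test set confusion matrix from above (rows = prediction, columns = truth)
cm <- matrix(c(261, 188, 1004, 1045), nrow = 2,
             dimnames = list(Prediction = c("Aang", "Other"),
                             Truth      = c("Aang", "Other")))

sens <- cm["Aang", "Aang"] / sum(cm[, "Aang"])     # share of Aang's lines we catch
spec <- cm["Other", "Other"] / sum(cm[, "Other"])  # share of other lines correctly labeled Other
acc  <- sum(diag(cm)) / sum(cm)                    # share of all lines classified correctly

round(c(sensitivity = sens, specificity = spec, accuracy = acc), 3)
# sensitivity about 0.581, specificity about 0.510, accuracy about 0.523
```

In words, on the testing set the model catches a bit more than half of Aang’s lines. At the same time, it flags roughly half of everyone else’s lines as Aang’s.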